import os
import os.path as osp
import random
import subprocess
import time
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import mmengine
from mmengine.config import ConfigDict
from mmengine.utils import track_parallel_progress

from opencompass.registry import RUNNERS, TASKS
from opencompass.utils import get_logger

from .base import BaseRunner


@RUNNERS.register_module()
class SlurmRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name. Defaults to None.
        quotatype (str): Slurm quota type. Defaults to None.
        qos (str): Slurm quality of service. Defaults to None.
        debug (bool): Whether to run in debug mode. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
        extra_command (List, optional): Extra slurm command.
            For example ['-c 12', '-w node1']. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: Optional[str] = None,
                 quotatype: Optional[str] = None,
                 qos: Optional[str] = None,
                 debug: bool = False,
                 lark_bot_url: Optional[str] = None,
                 extra_command: Optional[List[str]] = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos
        # Normalise ``None`` (or any other falsy value) to a fresh list so
        # that ``_launch`` can iterate over it unconditionally.
        if not extra_command:
            extra_command = []
        assert isinstance(extra_command, list)
        self.extra_command = extra_command

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        In debug mode the tasks are run one after another in the current
        process without the random start-up delay; otherwise they are
        dispatched through ``track_parallel_progress`` with
        ``max_num_workers`` processes and results may come back in any order.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """

        if not self.debug:
            status = track_parallel_progress(self._launch,
                                             tasks,
                                             nproc=self.max_num_workers,
                                             keep_order=False)
        else:
            status = [self._launch(task, random_sleep=False) for task in tasks]
        return status

    def _launch(self,
                cfg: ConfigDict,
                random_sleep: bool = True) -> Tuple[str, int]:
        """Launch a single task.

        The task config is dumped to a temporary file under ``tmp/``, an
        ``srun`` command line is assembled from the runner options, and the
        command is executed. A run counts as failed when the exit code is
        non-zero or any expected output path is missing; failed runs are
        repeated up to ``self.retry`` times. The temporary config file and
        the log file handle are always released, even if an error occurs.

        Args:
            cfg (ConfigDict): Task config.
            random_sleep (bool): Whether to sleep for a random time before
                running the command. This avoids cluster error when launching
                multiple tasks at the same time. Default: True.

        Returns:
            tuple[str, int]: Task name and exit code.
        """
        # NOTE(review): ``self.task_cfg`` is not set in this class; it is
        # presumably stored by ``BaseRunner.__init__`` - confirm.
        task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
        num_gpus = task.num_gpus
        task_name = task.name

        # Dump task config to file. The name is keyed on the PID so that
        # different worker processes do not write to the same file.
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        # ``None`` means "inherit the parent's stdout/stderr" (debug mode);
        # otherwise it becomes the opened log file, closed in ``finally``.
        stdout = None
        try:
            cfg.dump(param_file)

            # Build up slurm command
            tmpl = 'srun'
            if self.partition:
                tmpl += f' -p {self.partition}'
            if self.quotatype:
                tmpl += f' --quotatype={self.quotatype}'
            if self.qos:
                tmpl += f' --qos={self.qos}'
            if num_gpus > 0:
                tmpl += f' --gres=gpu:{num_gpus}'
            for extra_cmd in self.extra_command:
                tmpl += f' {extra_cmd}'
            # ``{task_cmd}`` is left as a literal placeholder for
            # ``task.get_command`` to fill in.
            tmpl += f" -N1 -u -J '{task_name[:512]}'" + ' {task_cmd}'
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger = get_logger()
            logger.debug(f'Running command: {cmd}')

            # Run command with retry
            if not self.debug:
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                stdout = open(out_path, 'w', encoding='utf-8')

            if random_sleep:
                time.sleep(random.randint(0, 10))
            result = subprocess.run(cmd,
                                    shell=True,
                                    text=True,
                                    stdout=stdout,
                                    stderr=stdout)

            retry = self.retry
            output_paths = task.get_output_paths()
            while self._job_failed(result.returncode,
                                   output_paths) and retry > 0:
                retry -= 1
                if random_sleep:
                    time.sleep(random.randint(0, 10))
                # Re-generate command to refresh ports.
                cmd = get_cmd()
                result = subprocess.run(cmd,
                                        shell=True,
                                        text=True,
                                        stdout=stdout,
                                        stderr=stdout)

            if result.returncode != 0 and not self.debug:
                logger.error(f'task {task_name} fail, see\n{out_path}')
        finally:
            # Clean up. Close the log file first so that its handle is not
            # leaked and buffered output is flushed to disk.
            if stdout is not None:
                stdout.close()
            # The config file may not exist if dumping itself failed; do not
            # let a missing file mask the original exception.
            try:
                os.remove(param_file)
            except FileNotFoundError:
                pass
        return task_name, result.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        """Tell whether a finished job should be treated as failed.

        Args:
            return_code (int): Exit code of the launched command.
            output_paths (list[str]): Paths the job is expected to produce.

        Returns:
            bool: True if the exit code is non-zero or at least one of the
            expected output paths does not exist.
        """
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)
